# -------------------------------------------------------------------
# Maps of the different variables 
# -------------------------------------------------------------------
library(pacman)
p_load(
  data.table, here, ggplot2, fixest, magrittr, dplyr, ggExtra, 
  haven, purrr, sf, tigris, tidyverse, janitor, yardstick
)

# Session-wide defaults: set the tigris cache option and make the minimal
# ggplot2 theme the default for every figure produced below.
options(list(tigris_use_cache = TRUE))
theme_minimal() |> theme_set()

# Predictions from structural code.
# The three cleaned files are bound column-wise, so they are assumed to have
# the same number of rows in the same order -- TODO confirm against the code
# that writes them.
# The census-regions lookup is then attached by state code; all.x = TRUE keeps
# rows of the bound table that have no match in the lookup.
merged_dt = 
  cbind(
    fread(here("Data/CleanedData/MergedData.csv")),
    read_dta(here("Data/CleanedData/BI_tCollPercPolnm0Base.dta")),
    read_dta(here("Data/CleanedData/HH_tract.dta"))
  ) |>
  merge(
    fread(here("Data/census-regions.csv")),
    by.x = "state", 
    by.y = "State Code",
    # Spelled out: `T` is an ordinary variable that can be reassigned.
    all.x = TRUE
  )

# Tract-level outcomes, added to merged_dt by reference:
#   fips              - left-padded with zeros to 11 characters
#   pred_linstall_pop - log of predicted installs per capita
#   prob_b_sim/_data  - simulated / observed installs per household
merged_dt[, `:=`(
  fips = str_pad(fips, width = 11, side = "left", pad = "0"),
  pred_linstall_pop = log(BI_t_pred / population),
  prob_b_sim = BI_t_pred / HH_tract,
  prob_b_data = BI_t / HH_tract
)]
# Logs of the two shares just stored (identical to log(BI / HH_tract)).
merged_dt[
  ,
  c("l_prob_b_sim", "l_prob_b_data") := list(log(prob_b_sim), log(prob_b_data))
]

# Raw google and deep solar data
# region_name is stored as character; deep-solar fips is left-padded with
# zeros to 11 characters so it can be joined to the tract geoid.
google_dt = fread(
  here("Data/google-project-sunroof/project-sunroof-census_tract.csv")
)[, region_name := as.character(region_name)]
deep_solar_dt = fread(
  here("Data/deep-solar/deepsolar_tract.csv")
)[, fips := str_pad(fips, width = 11, side = "left", pad = "0")]

# Now getting shape files
# State outlines (year 2000 cartographic boundaries) with the listed STATE
# codes removed and column names passed through clean_names() (the `state`
# column is used further down).
states_sf = clean_names(
  filter(
    states(cb = TRUE, year = 2000),
    # Limiting to continental US
    !STATE %in% c("02", "15", "60", "66", "69", "72", "78")
  )
)

# Tract polygons (2015 cartographic boundaries), downloaded one state at a
# time for every distinct state_fips in merged_dt and row-bound together.
tract_sf =
  merged_dt$state_fips |>
  unique() |>
  str_pad(width = 2, side = "left", pad = "0") |>
  map_dfr(function(st) tracts(state = st, cb = TRUE, year = 2015)) |>
  clean_names()

# County polygons (2010 cartographic boundaries) for every state kept in
# states_sf, with a GEOID built by concatenating the state and county codes.
# NOTE(review): st_simplify() runs before the projection step, so dTolerance
# is applied to the unprojected geometry -- confirm 0.001 is the intended
# tolerance in those units.
county_sf =
  states_sf$state |>
  map_dfr(\(st) counties(st, year = 2010, cb = TRUE)) |>
  clean_names() |>
  st_simplify(dTolerance = 0.001) |>
  mutate(GEOID = paste0(statefp, countyfp)) |>
  st_transform(crs = 2163)

# Merging data w/tract and exporting --------------------------------
# Tract polygons are the left-hand table in both joins, so every tract is
# kept; model outputs and deep-solar variables are attached by tract geoid,
# and the result is reprojected to CRS 2163.
merged_sf =
  tract_sf |>
  left_join(
    merged_dt[, list(
      fips,
      BI_t, BI_t_pred, HH_tract,
      prob_b_sim, prob_b_data,
      l_prob_b_sim, l_prob_b_data,
      A_t, CollPerc_t, Pol_t,
      ptAt, tot_ben
    )],
    by = c("geoid" = "fips")
  ) |>
  left_join(
    deep_solar_dt[, list(
      fips, county,
      solar_system_count,
      solar_system_count_residential,
      # Per-capita system counts
      ss_pc = solar_system_count / population,
      ssr_pc = solar_system_count_residential / population,
      total_panel_area,
      total_panel_area_residential,
      daily_solar_radiation,
      voting_2016_dem_percentage,
      education_college_rate,
      population,
      avg_electricity_retail_rate,
      # Retail electricity rate times daily solar radiation
      pr_rad = avg_electricity_retail_rate * daily_solar_radiation
    )],
    by = c("geoid" = "fips")
  ) |>
  st_transform(crs = 2163)

# Exporting so we can use QGIS for tract level things
# append = FALSE is passed so the layer is written fresh rather than
# appended to.
merged_sf |>
  st_write(
    dsn = here("Data/CleanedData/merged-shape/merged-sf.shp"),
    append = FALSE
  )

# Checking that merge worked
# Plain-table copy of the joined data without the geometry column.
merged_tract_dt = data.table(merged_sf)[, geometry := NULL]
# Compares the number of joined rows with a non-missing BI_t to the number of
# rows in merged_dt; evaluates to TRUE when the two counts agree.
nrow(merged_tract_dt[!is.na(BI_t)]) == nrow(merged_dt)

# County level plots for merged_dt data -----------------------------
# County code = first five characters of the tract fips.
merged_dt[, county_fips := str_sub(fips, 1, 5)]

# Aggregating to the county level
# The .SDcols variables are averaged with household (HH_tract) weights;
# population, installs, predicted installs and households are summed.
merged_county_dt = merged_dt[
  ,
  c(
    lapply(.SD, weighted.mean, w = HH_tract, na.rm = TRUE),
    list(
      tot_pop = sum(population, na.rm = TRUE),
      BI_t = sum(BI_t, na.rm = TRUE),
      BI_t_pred = sum(BI_t_pred, na.rm = TRUE),
      HH_tract = sum(HH_tract, na.rm = TRUE)
    )
  ),
  by = county_fips,
  .SDcols = c("A_t", "CollPerc_t", "Pol_t", "ptAt", "tot_ben")
]

# County-level outcomes from the summed totals: log predicted installs per
# capita and simulated / observed installs per household.
merged_county_dt[, `:=`(
  pred_linstall_pop = log(BI_t_pred / tot_pop),
  prob_b_sim = BI_t_pred / HH_tract,
  prob_b_data = BI_t / HH_tract
)]
# Logs of the two shares just stored (identical to log(BI / HH_tract)).
merged_county_dt[
  ,
  c("l_prob_b_sim", "l_prob_b_data") := list(log(prob_b_sim), log(prob_b_data))
]

# Merging with sf
# Only county aggregates with an observed share below 1 are joined; counties
# filtered out here (or absent from the data) keep their polygon with NA
# values because county_sf is the left-hand table.
merged_county_sf =
  county_sf |>
  left_join(
    merged_county_dt[prob_b_data < 1],
    by = c("GEOID" = "county_fips")
  )

# County choropleths of installs per household. Fill and outline colour both
# map to the share (log-transformed scale, percent labels); state borders are
# overlaid in black.

# Simulated shares
l_prob_b_sim_cmap =
  ggplot() +
  geom_sf(
    data = merged_county_sf,
    mapping = aes(fill = prob_b_sim, color = prob_b_sim)
  ) +
  geom_sf(data = states_sf, fill = NA, color = "black", size = 0.25) +
  scale_fill_viridis_c(
    name = "Installs",
    trans = "log",
    labels = scales::percent
  ) +
  scale_color_viridis_c(
    name = "Installs",
    trans = "log",
    labels = scales::percent
  ) +
  labs(
    title = "Simulated",
    caption = "Installs is the log(installs/houses), with installs and houses summed over all tracts in a county."
  )

# Observed shares
l_prob_b_data_cmap =
  ggplot() +
  geom_sf(
    data = merged_county_sf,
    mapping = aes(fill = prob_b_data, color = prob_b_data)
  ) +
  geom_sf(data = states_sf, fill = NA, color = "black", size = 0.25) +
  scale_fill_viridis_c(
    name = "Installs",
    trans = "log",
    labels = scales::percent
  ) +
  scale_color_viridis_c(
    name = "Installs",
    trans = "log",
    labels = scales::percent
  ) +
  ggtitle("Actual")
  
# Save the observed and simulated county maps side by side.
# arrangeGrob() builds the two-column layout without drawing it, whereas
# grid.arrange() would also render it on the active graphics device as a side
# effect (e.g. leaving a stray Rplots.pdf when run non-interactively).
ggsave(
  plot = gridExtra::arrangeGrob(
    l_prob_b_data_cmap, l_prob_b_sim_cmap,
    ncol = 2
  ),
  filename = here("figures/maps/prob-b-county.jpeg"),
  width = 12, height = 6
)


# -------------------------------------------------------------------
# Aggregating to the state level
# Installs, predicted installs, households and population are summed over
# tracts within each (state code, state_name) pair; the state code is the
# first two characters of county_fips.
merged_state_dt = merged_dt[
  ,
  list(
    BI_t = sum(BI_t, na.rm = TRUE),
    BI_t_pred = sum(BI_t_pred, na.rm = TRUE),
    HH_tract = sum(HH_tract, na.rm = TRUE),
    population = sum(population, na.rm = TRUE)
  ),
  by = list(state = str_sub(county_fips, 1, 2), state_name)
]

# State-level outcomes from the summed totals: log predicted installs per
# capita and simulated / observed installs per household.
merged_state_dt[, `:=`(
  pred_linstall_pop = log(BI_t_pred / population),
  prob_b_sim = BI_t_pred / HH_tract,
  prob_b_data = BI_t / HH_tract
)]
# Logs of the two shares just stored (identical to log(BI / HH_tract)).
merged_state_dt[
  ,
  c("l_prob_b_sim", "l_prob_b_data") := list(log(prob_b_sim), log(prob_b_data))
]


# State polygons merged with the state totals and with a per-state R-squared
# (yardstick rsq) of simulated against observed tract-level shares, then
# reprojected to CRS 2163. Both merges keep only states present on each side.
state_install_sf =
  states_sf |>
  merge(merged_state_dt, by = "state") |>
  merge(
    merged_dt |>
      group_by(state_fips) |>
      rsq(truth = prob_b_data, estimate = prob_b_sim, na_rm = TRUE) |>
      # Keep just the zero-padded state code and the estimate
      transmute(
        state = str_pad(state_fips, side = "left", width = 2, pad = "0"),
        rsq = .estimate
      ),
    by = "state"
  ) |>
  st_transform(crs = 2163)


# State map of observed installs per household (log-transformed fill scale,
# percent labels, no polygon outlines), saved as a 10 x 7 JPEG.
state_install_map =
  ggplot(data = state_install_sf, mapping = aes(fill = prob_b_data)) +
  geom_sf(color = NA) +
  scale_fill_viridis_c(
    name = "Installs per House",
    trans = "log",
    labels = scales::percent
  ) +
  labs(title = "Actual Installs per House")
ggsave(
  filename = here("figures/maps/state_install_map.jpeg"),
  plot = state_install_map,
  width = 10,
  height = 7
)

# State map of simulated installs per household (log-transformed fill scale,
# percent labels, no polygon outlines), saved as a 10 x 7 JPEG.
state_pred_install_map =
  ggplot(data = state_install_sf, mapping = aes(fill = prob_b_sim)) +
  geom_sf(color = NA) +
  scale_fill_viridis_c(
    name = "Installs per House",
    trans = "log",
    labels = scales::percent
  ) +
  labs(title = "Predicted Installs per House")
ggsave(
  filename = here("figures/maps/state_pred_install_map.jpeg"),
  plot = state_pred_install_map,
  width = 10,
  height = 7
)

# State map of the per-state R-squared (log-transformed fill scale, percent
# labels, no polygon outlines), saved as a 10 x 7 JPEG.
state_rsq_map =
  ggplot(data = state_install_sf, mapping = aes(fill = rsq)) +
  geom_sf(color = NA) +
  scale_fill_viridis_c(
    name = "R2 (Log scale)",
    trans = "log",
    labels = scales::percent
  ) +
  labs(title = "R-Squared by State")
ggsave(
  filename = here("figures/maps/state_rsq_map.jpeg"),
  plot = state_rsq_map,
  width = 10,
  height = 7
)

